# coding: utf-8

# -------------------------------------------------------------------------------
# Name:         mysql_client.py
# Description:  Mysql客户端
# Author:       XiangjunZhao
# EMAIL:        2419352654@qq.com
# Date:         2020/3/5 13:26
# -------------------------------------------------------------------------------
import logging

import pymysql
from DBUtils.PooledDB import PooledDB
from sshtunnel import SSHTunnelForwarder

logger = logging.getLogger(__name__)


class MysqlClient(object):
    """
    Mysql客户端
    """

    def __init__(self, host=None, port=3306, user=None, password=None, db=None, charset='utf8', ssh_host=None,
                 ssh_username=None, ssh_password=None, use_ssh_tunnel=False):
        """
        数据库工具类初始化方法
        Args:
            host: 数据库地址
            port: 数据库商品
            user: 数据库用户名
            password: 数据库密码
            db: 数据库名
            charset: 数据库编码
            ssh_host: ssh地址
            ssh_username: ssh用户名
            ssh_password: ssh密码
            use_ssh_tunnel: 使用ssh隧道
        """
        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.db = db
        self.charset = charset
        self.ssh_host = ssh_host
        self.ssh_username = ssh_username
        self.ssh_password = ssh_password
        self.use_ssh_tunnel = use_ssh_tunnel
        self.tunnel = None
        try:
            if self.use_ssh_tunnel:
                self.tunnel = SSHTunnelForwarder(
                    ssh_address_or_host=(self.ssh_host, 22),
                    ssh_username=self.ssh_username,
                    ssh_password=self.ssh_password,
                    remote_bind_address=(self.host, self.port))
                self.tunnel.start()

                self.db_pool = PooledDB(pymysql, mincached=1, maxcached=5, host='127.0.0.1',
                                        port=self.tunnel.local_bind_port, user=self.user, passwd=self.password,
                                        db=self.db, charset=self.charset)
            else:
                self.db_pool = PooledDB(pymysql, mincached=1, maxcached=5, host=self.host, port=self.port,
                                        user=self.user, passwd=self.password, db=self.db, charset=self.charset,
                                        connect_timeout=60)
        except Exception as e:
            if self.tunnel is not None:
                self.tunnel.close()
            logger.error('Mysql数据库工具初始化异常，原因：{}'.format(str(e)))

    def get_conn(self):
        """
        获取数据库连接
        Returns:

        """
        try:
            return self.db_pool.connection()
        except Exception as e:
            logger.error('获取Mysql数据库连接异常，原因：{}'.format(str(e)))

    def close_conn(self, conn):
        """
        关闭数据库连接
        Args:
            conn: 数据库连接

        Returns:

        """
        if conn:
            try:
                conn.close()
            except pymysql.Error as e:
                logger.error('关闭Mysql数据库连接异常，原因：{}'.format(str(e)))

    def close(self):
        """
        断开连接
        Returns:

        """
        try:
            self.db_pool.close()
        except Exception as e:
            logger.error('Mysql数据库断开连接异常，原因：{}'.format(str(e)))
        finally:
            if self.tunnel is not None:
                self.tunnel.close()

    def query_one(self, conn=None, sql=None):
        """
        查询一条数据
        Args:
            conn: 数据库连接
            sql: 查询sql

        Returns:

        """
        # 获取数据库游标
        with conn.cursor() as cursor:
            # 执行查询sql语句
            cursor.execute(sql)
            # 获取sql查询的列名
            desc = [item[0] for item in cursor.description]
            # 获取sql查询的结果
            data = cursor.fetchone()
            return dict(zip(desc, data)) if data else dict()

    def query_many(self, conn=None, sql=None):
        """
        查询所有数据
        Args:
            conn: 数据库连接
            sql: 查询sql

        Returns:

        """
        # 获取数据库游标
        with conn.cursor() as cursor:
            # 执行查询sql语句
            cursor.execute(sql)
            # 获取sql查询的列名
            desc = [item[0] for item in cursor.description]
            # 获取sql查询的结果
            data = cursor.fetchall()
            return [dict(zip(desc, item)) for item in data] if data else list()

    def execute_one(self, conn=None, sql=None):
        """
        执行单条 插入、更新、删除sql
        Args:
            conn: 数据库连接
            sql: 插入、更新、删除sql

        Returns:

        """
        # 获取数据库游标
        with conn.cursor() as cursor:
            try:
                cursor.execute(sql)
            except pymysql.Error as e:
                # 执行插入、更新、删除sql时发生异常，事务回滚
                conn.rollback()
                logger.error('Mysql数据库执行SQL异常，原因：{}'.format(str(e)))
            else:
                # 成功执行插入、更新、删除sql后，提交事务
                conn.commit()

    def execute_many(self, conn=None, sql=None, datas=None):
        """
        执行单条 插入、更新、删除sql
        Args:
            conn: 数据库连接
            sql: 插入、更新、删除sql
            datas: 插入、更新、删除sql的参数，参数格式为元组列表，示例：[(1,2,3,4),(1,2,3,4)]

        Returns:

        """
        # 获取数据库游标
        with conn.cursor() as cursor:
            try:
                cursor.executemany(sql, datas)
            except pymysql.Error as e:
                # 执行插入、更新、删除sql时发生异常，事务回滚
                conn.rollback()
                logger.error('Mysql数据库执行SQL异常，原因：{}'.format(str(e)))
            else:
                # 成功执行插入、更新、删除sql后，提交事务
                conn.commit()

    def change_db(self, conn=None, db=''):
        """
        切换数据库
        Args:
            conn: 数据库连接
            db: 待切换的数据库

        Returns:

        """
        # 获取数据库游标
        with conn.cursor() as cursor:
            try:
                sql = 'use {db};'.format(db=db)
                cursor.execute(sql)
            except pymysql.Error as e:
                conn.rollback()
                logger.error('Mysql数据库执行SQL异常，原因：{}'.format(str(e)))
            else:
                # 成功执行sql后，提交事务
                conn.commit()
